## These functions create the subgroup and salience comparison plots.
## Adapted from:
## Clayton, Katherine, Jeremy Ferwerda, and Yusaku Horiuchi.
## "Exposure to immigration and admission preferences: Evidence from France." Political Behavior 43.1 (2021): 175-200.

## ggplot theme used by the difference plot: theme_few() at the requested
## font size/family, with black axis text, grey50 axis ticks, a y-axis title
## justified with vjust = 0.01 / hjust = 0.1, and the legend placed on top.
## The overrides are combined with %+replace%, so each listed element is
## replaced wholesale rather than merged into theme_few()'s version.
set_my_ggtheme <- function(base_size = 10, base_family = "") {
  # Both axis-text elements share size and colour; only justification differs.
  black_axis_text <- function(h, v) {
    element_text(size = base_size, colour = "black", hjust = h, vjust = v)
  }

  overrides <- theme(
    axis.text.x = black_axis_text(h = 0.5, v = 1),
    axis.text.y = black_axis_text(h = 0, v = 0.5),
    axis.ticks = element_line(colour = "grey50"),
    axis.title.y = element_text(size = base_size, angle = 90,
                                vjust = 0.01, hjust = 0.1),
    legend.position = "top"
  )

  base_theme <- theme_few(base_size = base_size, base_family = base_family)
  base_theme %+replace% overrides
}


## Compute a plotting rank for each row of a results table.
##
## Rows are ranked so that attributes come in reverse `attribute_order`
## sequence. Within an attribute, its levels come in reverse `level_order`
## sequence, followed by the attribute's header row (a row whose `level` is
## NA). Levels missing from `level_order` also become NA and sort like header
## rows. Used with reorder() + coord_flip(), the plot presumably reads
## top-down in the given orders.
##
## Arguments:
##   attribute       Vector of attribute names, one per row.
##   level           Vector of level labels (NA for header rows), same length.
##   attribute_order Attribute names in display order. Defaults to the variable
##                   `Attribute_order` visible where this function is defined
##                   (normally the global environment), as before.
##   level_order     Level labels in display order. Defaults likewise to
##                   `Level_order`.
## Returns: numeric vector of ranks (ties averaged, as in rank()).
order_attributes_levels <- function(attribute, level,
                                    attribute_order = Attribute_order,
                                    level_order = Level_order) {

  attribute <- factor(attribute, levels = attribute_order)
  level <- factor(level, levels = level_order)
  level <- as.numeric(level)

  # Spacing between attribute blocks. It must exceed the number of levels (so
  # order1 keeps attributes contiguous) and the number of rows (so the shift
  # in order3 dominates order2). It is the original 1000 for any normal table.
  block <- max(1000, length(level_order) + 1, length(attribute) + 1)

  order1 <- as.numeric(attribute) * block + level   # NA for header rows
  order2 <- rank(-order1)                           # levels descending; NAs last
  order3 <- -as.numeric(attribute) * block + order2 # attributes descending
  rank(order3)

}

#options(stringsAsFactors = FALSE)

## Plot subgroup comparisons of the conjoint results.
##
## Arguments:
##   mod        Name (string) of the moderator column. The subgroups are the
##              respondents with moderator == 0 and moderator == 1.
##   mod_labels Character vector of length 2: labels for moderator == 0 and
##              moderator == 1, in that order.
##   wave       1 uses `cdata_w1` and the wave-1 attribute/level orders; any
##              other value uses `cdata` and the wave-2 orders. Both data
##              frames are looked up outside this function.
##   fig        1 prints marginal means per subgroup plus their difference;
##              any other value builds the attribute-salience plots and
##              returns them side by side via ggarrange().
##   alpha      Significance level for all intervals (lm_robust intervals in
##              fig 1, percentile-bootstrap intervals in fig 2).
##   bs_n       Number of respondent-level bootstrap draws (fig 2 only).
##
## NOTE(review): relies on dplyr/tidyr/purrr/stringr, estimatr (lm_robust),
## a tidy() method, ggplot2, ggthemes, ggrepel and ggpubr being attached by
## the caller.
vis_results <- function(mod, mod_labels, wave = 2, fig = 1, alpha = 0.05,
                        bs_n = 1000) {

  Attribute_order_w2 <- c("Health.Status",
                          "Age",
                          "Occupation",
                          "Gender",
                          "Education",
                          "Economic.Status",
                          "Place.of.Birth",
                          "Political.ID")

  Level_order_w2 <- c(c("Disabled","Chronic condition","Good health","Excellent health"),
                      c("Under 30","31-49","50-64","65-79","Over 80"),
                      c("Essential worker","Law Enforcement","Nurse","Homeworker","Unemployed"),
                      c("Woman","Man"),
                      c("Lower Secondary","Vocational","College preparatory","Bachelors"),
                      c("Poor","Middle class","Well off","Millionaire"),
                      c("Northern Italy","Central Italy","Southern Italy","Spain","Morocco","Nigeria"),
                      c("Partito Democratico","Movimento 5 Stelle","Forza Italia","Lega","Fratelli d'Italia"))

  Attribute_order_w1 <- c("Health.Status",
                          "Age",
                          "Gender",
                          "Education",
                          "Economic.Status",
                          "Place.of.Birth",
                          "Political.ID")

  Level_order_w1 <- c(c("Poor health","Good health","Excellent health"),
                      c("Under 30","31-45","46-60","Over 60"),
                      c("Uomo","Donna"),
                      c("Lower Secondary","Vocational","College preparatory","Bachelors"),
                      c("Poor","Middle class","Well off","Millionaire"),
                      c("Northern Italy","Central Italy","Southern Italy","Spain","Netherlands","China","Morocco","Nigeria"),
                      c("Partito Democratico","Movimento 5 Stelle","Forza Italia","Lega","Fratelli d'Italia"))

  # `edge`: number of leading columns kept (after dropping respondentIndex).
  if (wave == 1) {
    edge <- 13
    Attribute_order <- Attribute_order_w1
    Level_order <- Level_order_w1
    data <- cdata_w1
    title <- "Health Guideline Violations"
  } else {
    edge <- 14
    Attribute_order <- Attribute_order_w2
    Level_order <- Level_order_w2
    data <- cdata
    title <- "Vaccine Priority"
  }

  # order_attributes_levels() is defined at top level, so its free variables
  # Attribute_order / Level_order are resolved in the global environment and
  # never see the wave-specific orders chosen above. Re-home a copy into this
  # call's frame so it ranks rows with the right orders.
  order_rows <- order_attributes_levels
  environment(order_rows) <- environment()

  ## Count respondents per subgroup and add N to the panel labels
  check_n <- data %>%
    select(respondent, all_of(mod)) %>%
    distinct() %>%
    count(.data[[mod]])

  n0 <- check_n %>% filter(.data[[mod]] == 0) %>% pull(n)
  n1 <- check_n %>% filter(.data[[mod]] == 1) %>% pull(n)
  plain_labels <- mod_labels  # labels without N, used in the fig-2 titles
  mod_labels <- c(paste0(mod_labels[1], "\n(N = ", n0, ")"),
                  paste0(mod_labels[2], "\n(N = ", n1, ")"))

  ## Reshape to long format: one row per original row and attribute column.
  ## NOTE(review): assumes the attribute columns start at position 5 after
  ## the relocations below -- confirm against the data layout.
  df <- data %>%
    select(-respondentIndex) %>%
    select(1:!!edge, moderator = all_of(mod)) %>%
    select(-PID, -task, -Response.ID) %>%
    mutate_if(is.factor, as.character) %>%
    relocate(selected, .after = profile) %>%
    relocate(moderator, .after = profile) %>%
    gather("attribute", "level", 5:ncol(.)) %>%
    mutate(attribute_level = paste(attribute, level, sep = ": "))

  if (fig == 1) {
    ### First figure (Difference) #############################################

    # Lookup from "attribute: level" back to its two parts
    attribute_levels_only <- df %>%
      select(attribute_level, attribute, level) %>%
      distinct()

    # Fit `formula` separately for each attribute-level (optionally within one
    # moderator subgroup) with respondent-clustered "stata" SEs at level alpha.
    fit_per_level <- function(formula, subgroup = NULL) {
      df %>%
        nest(data = -attribute_level) %>%
        mutate(
          result = map(data, function(d) {
            if (!is.null(subgroup)) {
              d <- filter(d, moderator == subgroup)
            }
            lm_robust(formula, data = d, clusters = respondent,
                      se_type = "stata", alpha = alpha)
          }),
          tidied = map(result, tidy)
        ) %>%
        unnest(tidied) %>%
        as_tibble()
    }

    # Keep the plotted columns and tag rows with their facet panel
    to_panel <- function(fits, panel_label) {
      fits %>%
        select(attribute_level, estimate, p.value, conf.low, conf.high) %>%
        mutate(panel = panel_label)
    }

    ## Marginal means within each subgroup
    m0 <- to_panel(fit_per_level(selected ~ 1, subgroup = 0), mod_labels[1])
    m1 <- to_panel(fit_per_level(selected ~ 1, subgroup = 1), mod_labels[2])

    ## Difference in marginal means (coefficient on the moderator)
    m_diff <- fit_per_level(selected ~ moderator) %>%
      filter(str_detect(term, "moderator")) %>%
      to_panel("Difference")

    ## One header row (level = NA) per attribute and panel
    attribute_only <- data.frame(
      attribute = rep(Attribute_order, 3),
      panel = rep(c(mod_labels[1], mod_labels[2], "Difference"),
                  each = length(Attribute_order)),
      stringsAsFactors = FALSE)

    # color: 1 = non-significant difference, 2 = subgroup mean,
    # 3 = difference whose interval excludes 0
    results1 <- bind_rows(m0, m1, m_diff) %>%
      left_join(attribute_levels_only, by = "attribute_level") %>%
      bind_rows(attribute_only) %>%
      group_by(panel) %>%
      mutate(order = order_rows(attribute, level)) %>%
      ungroup() %>%
      arrange(order) %>%
      mutate(var.names = paste0(attribute, ":"),
             var.names = ifelse(!is.na(level), paste0("   ", level), var.names)) %>%
      mutate(color = 1,
             color = ifelse(panel %in% c(mod_labels[1], mod_labels[2]), 2, color),
             color = ifelse(panel == "Difference" & conf.low > 0 & conf.high > 0, 3, color),
             color = ifelse(panel == "Difference" & conf.low < 0 & conf.high < 0, 3, color)) %>%
      mutate(panel = factor(panel,
                            levels = c(mod_labels[2], mod_labels[1], "Difference")))

    ## Visualize differences
    g1 <- ggplot() +
      geom_hline(data = results1 %>% filter(panel == "Difference"),
                 mapping = aes(yintercept = 0),
                 color = "gray80",
                 linetype = "dashed") +
      geom_hline(data = results1 %>% filter(panel != "Difference"),
                 mapping = aes(yintercept = 0.5),
                 color = "gray80",
                 linetype = "dashed") +
      geom_pointrange(data = results1,
                      mapping = aes(x = reorder(var.names, order),
                                    y = estimate,
                                    ymax = conf.high,
                                    ymin = conf.low,
                                    color = factor(color)),
                      na.rm = TRUE) +
      facet_wrap(~ panel, scales = "free_x") +
      # Named values keep colors stable when a category is absent (e.g. no
      # color "1" when every difference is significant).
      scale_color_manual(values = c("1" = "darkgray", "2" = "darkgray",
                                    "3" = "black")) +
      guides(color = "none") +
      coord_flip() +
      labs(y = "\nExpected value of outcome variable",
           x = NULL) +
      set_my_ggtheme() +
      ggtitle(title)

    print(g1)

  } else {
    ### Second figure (Salience) ##############################################

    ## Cluster bootstrap: each of the bs_n draws holds length(unique_id)
    ## respondents sampled with replacement.
    unique_id <- df %>%
      distinct(respondent) %>%
      pull(respondent)
    n_resp <- length(unique_id)

    df_bootstrapped <- data.frame(
      draws = rep(seq_len(bs_n), n_resp),
      # Index-based sampling avoids sample()'s 1:x behaviour for a single
      # numeric id; for n_resp > 1 it draws exactly what sample() would.
      respondent = unique_id[sample.int(n_resp, n_resp * bs_n, replace = TRUE)]
    ) %>%
      left_join(df, by = "respondent")

    ## Bootstrap distribution of the subgroup difference in salience
    results2b <- df_bootstrapped %>%
      group_by(draws, attribute, level, moderator) %>%
      summarise(average = mean(selected), .groups = "drop") %>%
      mutate(salience = abs(average - 0.5)) %>%
      group_by(draws, moderator, attribute) %>%
      summarise(average_salience = mean(salience), .groups = "drop") %>%
      spread(moderator, average_salience) %>%
      mutate(diff = (`1` - `0`)) %>%
      group_by(attribute) %>%
      summarise(difference_mean = mean(diff),
                difference_bs.low = quantile(diff, alpha / 2),
                difference_bs.high = quantile(diff, 1 - alpha / 2),
                .groups = "drop") %>%
      mutate(sig = ifelse((difference_bs.low > 0 & difference_bs.high > 0) |
                            (difference_bs.low < 0 & difference_bs.high < 0), 1, 0))

    ## Point estimates of salience: mean |marginal mean - 0.5| per attribute
    results2a <- df %>%
      group_by(attribute, level, moderator) %>%
      summarise(average = mean(selected), .groups = "drop") %>%
      mutate(deviation = abs(average - 0.5)) %>%
      group_by(attribute, moderator) %>%
      summarise(salience = mean(deviation), .groups = "drop") %>%
      spread(moderator, salience) %>%
      left_join(results2b %>% select(attribute, sig), by = "attribute")

    # Widen the square panel past the default 0.11 when needed so that no
    # attribute is drawn outside the plotting region.
    upper <- max(0.11, results2a[["0"]], results2a[["1"]], na.rm = TRUE)
    salience_breaks <- seq(0, upper, by = 0.05)

    fig2_title <- paste0(title, ": ", plain_labels[2], " - ", plain_labels[1])
    # Named values: a single significance category still gets its own color.
    sig_colors <- c("0" = "darkgray", "1" = "black")

    ## Visualize attribute salience
    g2a <- ggplot(data = results2a,
                  aes(x = `0`,
                      y = `1`)) +
      geom_abline(slope = 1, color = "gray80") +
      geom_point(aes(color = as.factor(sig))) +
      geom_text_repel(aes(label = attribute,
                          color = as.factor(sig))) +
      coord_equal(ylim = c(0, upper),
                  xlim = c(0, upper)) +
      scale_x_continuous(breaks = salience_breaks) +
      scale_y_continuous(breaks = salience_breaks) +
      scale_color_manual(values = sig_colors) +
      guides(color = "none") +
      theme_few() +
      labs(x = mod_labels[1],
           y = mod_labels[2]) +
      ggtitle(fig2_title) +
      theme(plot.title = element_text(size = 9))

    g2b <- ggplot(data = results2b) +
      geom_hline(yintercept = 0,
                 color = "gray70",
                 linetype = "dashed") +
      geom_pointrange(aes(x = reorder(attribute, difference_mean),
                          y = difference_mean,
                          ymin = difference_bs.low,
                          ymax = difference_bs.high,
                          color = as.factor(sig))) +
      coord_flip() +
      scale_color_manual(values = sig_colors) +
      guides(color = "none") +
      labs(y = "Difference in attribute salience",
           x = NULL) +
      theme_few() +
      ggtitle(fig2_title) +
      theme(plot.title = element_text(size = 9))

    ggarrange(g2a, g2b, nrow = 1, ncol = 2)
  }
}
  